Allow Truncation of CustomDist #6947
    | Codecov Report
 Attention: Patch coverage is
 Additional details and impacted files
 @@            Coverage Diff             @@
##             main    #6947      +/-   ##
==========================================
- Coverage   92.30%   87.75%   -4.55%     
==========================================
  Files         100      100              
  Lines       16888    16958      +70     
==========================================
- Hits        15588    14882     -706     
- Misses       1300     2076     +776     
 | 
    | This PR depends on #7227 | 
    | Looks like some changes still needed till tests pass. | 
| 
 See my comment above, it needs the pytensor dependency bump which is happening in a separate PR | 
| Oh ok. I know it's not the most essential for this PR, but why does Scan require a shared variable as an argument? |
| 
 It's a limitation in the original implementation of Scan, where RNG variables must be shared. It's one of the things that we are hoping to solve with pymc-devs/pytensor#191 | 
LGTM
General comment: What do you think about a warning if the Truncated distribution has to fall back to rejection sampling? This will introduce a while scan into the graph that could be quite surprising to users (JAX mode no longer possible, potentially big performance hit)
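To make the concern concrete, here is a minimal sketch of rejection sampling for a truncated distribution. It is plain standard-library Python, not PyMC or PyTensor code: the function name `truncated_rejection_sample` and its `max_tries` parameter are hypothetical, and the base distribution exp(Normal(0, 1)) is only an assumed stand-in for a CustomDist-style base. The point it illustrates is that the number of iterations depends on the draws themselves, which is why a symbolic version needs a while-style loop.

```python
import math
import random


def truncated_rejection_sample(draw, lower, upper, rng, max_tries=10_000):
    """Draw one value from `draw(rng)` conditioned on lower <= value <= upper.

    Returns the accepted value and the number of attempts it took.
    The attempt count is data dependent, so it cannot be fixed in advance.
    """
    for tries in range(1, max_tries + 1):
        value = draw(rng)
        if lower <= value <= upper:
            return value, tries
    raise RuntimeError("no draw fell inside the bounds")


def base(rng):
    # Assumed stand-in for the base distribution: exp(Normal(0, 1)).
    return math.exp(rng.gauss(0.0, 1.0))


rng = random.Random(0)
samples = [truncated_rejection_sample(base, 1.0, 3.0, rng) for _ in range(1000)]
values = [value for value, _ in samples]
mean_tries = sum(tries for _, tries in samples) / len(samples)
```

The average number of attempts grows as the bounds cover less of the base distribution's mass, which is the potential performance hit mentioned above.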
    | 
 I wouldn't add a warning, because there is nothing for the user to do instead. Can add a note in the docstrings if there's no mention of it yet | 
Truncating a CustomDist is now possible.
This required cleaning up the interface of SymbolicRandomVariables (mainly circumventing pymc-devs/pytensor#473) so that we can safely "box" the base RVs in the inner OpFromGraph (i.e., recreate them with new shared inputs).
This challenge is very specific to Truncated, which needs to "resample" the base RV for the rejection-based algorithm.
No other SymbolicRandomVariable needs to do this, and they have avoided the need to box the base RVs by simply resizing them to the total size and using the resized RVs as explicit inputs to the inner graph.
For instance, Mixture will resize the component RVs to the "total size" and then stochastically index them based on its internal Categorical RV. ZeroSumNormal will create Normals as inputs and simply subtract the mean.
Such an approach, however, makes it tricky for Truncated to know exactly what constitutes the "true" inputs of the underlying SymbolicRandomVariables, and for this reason it rejected, and still rejects, arbitrary SymbolicRandomVariables. The exception is the SymbolicRandomVariables created via CustomDist, because those are already "pre-boxed" in a sense: we know the relevant graph must start at dist.owner.inputs. Now that our class can safely manage and replace shared RNG inputs, we can allow Truncated to handle such RVs, even if they require a bunch of shared RNGs.
Related to #6905 (comment)
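The difference between passing finished draws as explicit inputs and "boxing" the base RV can be sketched with a toy analogy. This is plain standard-library Python and does not use any PyMC or PyTensor API; `mixture_style`, `truncated_style`, `base_recipe` and `max_tries` are hypothetical names chosen for illustration only.

```python
import random


def mixture_style(component_draws, weights, rng):
    """Components were drawn up front and arrive as explicit inputs.

    The inner step only picks one of them, so it never has to redraw.
    """
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return component_draws[k]


def truncated_style(base_recipe, params, rng, lower, upper, max_tries=10_000):
    """The recipe for the base draw lives inside, with its RNG as an input.

    Because the recipe (not a finished draw) is boxed, it can be re-run
    with the same parameters until a value lands inside the bounds.
    """
    for _ in range(max_tries):
        value = base_recipe(rng, *params)
        if lower <= value <= upper:
            return value
    raise RuntimeError("no draw fell inside the bounds")


def normal_recipe(rng, mu, sigma):
    # Stand-in for a base RV whose true inputs are known to be (mu, sigma).
    return rng.gauss(mu, sigma)


rng = random.Random(42)

# Mixture-style: two finished draws go in, one of them comes out.
draws = [rng.gauss(-5.0, 1.0), rng.gauss(5.0, 1.0)]
picked = mixture_style(draws, weights=[0.3, 0.7], rng=rng)

# Truncated-style: the recipe and its parameters go in, so redraws are possible.
kept = truncated_style(normal_recipe, (0.0, 1.0), rng, lower=0.5, upper=2.0)
```

In the first function a rejected value could not be replaced, because only the finished draws are visible. The second function needs to know where the recipe starts and which inputs it takes, which is the role dist.owner.inputs plays for CustomDist in the description above.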
TODO